#
import argparse
from typing import Dict, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten -> (Linear, ReLU) x2 -> Linear.

    The defaults match the 28x28 single-channel images and 10 classes that
    this file trains on (FashionMNIST). The layer layout, and therefore the
    ``state_dict`` keys, is identical for any sizes, so checkpoints saved with
    the defaults keep loading.

    Args:
        in_features: Number of values in one flattened input sample.
        hidden_features: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, in_features: int = 28 * 28, hidden_features: int = 512,
                 num_classes: int = 10):
        super().__init__()
        # nn.Flatten keeps dim 0 (the batch) and merges all remaining dims.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores (logits), one row per sample."""
        return self.linear_relu_stack(self.flatten(x))

class FashionApp:
    """Train and run the FashionMNIST classifier defined by ``NeuralNetwork``.

    Every entry point is a static method; the class groups them and holds the
    shared ``device`` setting.
    """

    # Use the first CUDA device when one exists, otherwise the CPU. A
    # hard-coded 'cuda:0' fails on hosts without a GPU as soon as a tensor or
    # model is moved there.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    def __init__(self):
        self.name = 'fashion'

    @staticmethod
    def startup(params: Optional[Dict] = None) -> None:
        """Application entry point: print a banner, then run prediction.

        Args:
            params: Options forwarded to the selected action; ``None`` is
                treated as an empty dict.
        """
        params = {} if params is None else params
        print('Pytorch示例工程 v0.0.1')
        # predict() loads model.pth, which train_main() writes. To (re)train
        # first, call FashionApp.train_main(params=params) here.
        FashionApp.predict(params=params)

    @staticmethod
    def _load_split(train: bool):
        """Return the FashionMNIST train (``train=True``) or test split.

        The data is downloaded into ``./data`` if it is not already there, and
        each image is converted with ``ToTensor``.
        """
        return datasets.FashionMNIST(
            root="data",
            train=train,
            download=True,
            transform=ToTensor(),
        )

    @staticmethod
    def train_main(params: Optional[Dict] = None) -> None:
        """Train a fresh ``NeuralNetwork`` for 5 epochs and save it to model.pth.

        After each epoch the model is evaluated on the test split.

        Args:
            params: Currently unused; accepted for a uniform entry-point API.
        """
        training_data = FashionApp._load_split(train=True)
        test_data = FashionApp._load_split(train=False)
        batch_size = 64
        train_dataloader = DataLoader(training_data, batch_size=batch_size)
        test_dataloader = DataLoader(test_data, batch_size=batch_size)
        # Sanity check: show the shape/dtype of one test batch.
        for X, y in test_dataloader:
            print(f"Shape of X [N, C, H, W]: {X.shape}")
            print(f"Shape of y: {y.shape} {y.dtype}")
            break
        print(f"Using {FashionApp.device} device")
        model = NeuralNetwork().to(FashionApp.device)
        print(model)
        loss_fn = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        epochs = 5
        for t in range(epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            FashionApp.train(train_dataloader, model, loss_fn, optimizer)
            FashionApp.test(test_dataloader, model, loss_fn)
        print("Done!")
        torch.save(model.state_dict(), "model.pth")
        print("Saved PyTorch Model State to model.pth")

    @staticmethod
    def train(dataloader, model, loss_fn, optimizer):
        """Run one training epoch over ``dataloader``, updating ``model`` in place.

        Every 100 batches, print the current batch loss and how many samples
        have been processed so far.
        """
        size = len(dataloader.dataset)
        model.train()
        seen = 0  # running sample count; correct even for a short final batch
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(FashionApp.device), y.to(FashionApp.device)
            seen += len(X)
            # Forward pass and loss.
            pred = model(X)
            loss = loss_fn(pred, y)
            # Backpropagate, apply the update, then clear gradients for the
            # next batch.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if batch % 100 == 0:
                print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

    @staticmethod
    def test(dataloader, model, loss_fn):
        """Evaluate ``model`` on ``dataloader`` and print accuracy and loss.

        Returns:
            ``(accuracy, avg_loss)``: accuracy is the fraction (0..1) of
            samples whose argmax prediction equals the label; avg_loss is the
            mean of the per-batch ``loss_fn`` values.
        """
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        model.eval()
        test_loss, correct = 0.0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(FashionApp.device), y.to(FashionApp.device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).sum().item()
        test_loss /= num_batches
        accuracy = correct / size
        print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return accuracy, test_loss

    @staticmethod
    def predict(params: Optional[Dict] = None) -> None:
        """Load model.pth and print the prediction for the first test image.

        Args:
            params: Currently unused; accepted for a uniform entry-point API.
        """
        model = NeuralNetwork().to(FashionApp.device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host, and vice versa.
        state = torch.load("model.pth", map_location=FashionApp.device,
                           weights_only=True)
        model.load_state_dict(state)
        model.eval()
        classes = [
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ]
        test_data = FashionApp._load_split(train=False)
        x, y = test_data[0]
        with torch.no_grad():
            # A single ToTensor sample has no batch dim; add one so the model
            # sees a batch of size 1.
            x = x.unsqueeze(0).to(FashionApp.device)
            pred = model(x)
            predicted, actual = classes[pred[0].argmax(0).item()], classes[y]
            print(f'Predicted: "{predicted}", Actual: "{actual}"')













def main(params: Optional[Dict] = None) -> None:
    """Start the Fashion app with the given options.

    Args:
        params: Options dict (the script passes ``vars()`` of the parsed CLI
            namespace); ``None`` means no options. A ``None`` default avoids a
            shared mutable ``{}`` default.
    """
    FashionApp.startup(params={} if params is None else params)

def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        A namespace with one attribute, ``run_mode`` (int, default 1).
    """
    run_mode_spec = dict(
        action='store',
        type=int,
        default=1,
        dest='run_mode',
        help='run mode',
    )
    cli = argparse.ArgumentParser()
    cli.add_argument('--run_mode', **run_mode_spec)
    namespace = cli.parse_args()
    return namespace

if __name__ == '__main__':
    # Convert the parsed CLI namespace to a plain dict and hand it to main().
    cli_params = vars(parse_args())
    main(params=cli_params)